package generator

import (
	"bytes"
	"fmt"
	"html/template"
	"io"

	"gitee.com/lipore/plume/db_gorm/gen/code"
	"gitee.com/lipore/plume/db_gorm/gen/spec"
	"golang.org/x/tools/imports"
)

// GenerateRepository renders an implementation of the repository interface
// interfaceName from methodSpecs and returns the resulting formatted Go source.
//
// persistPackage becomes the package name passed to the file's base template;
// domainPackage names the domain model's package. persistStructModel (the
// entity) and domainStructModel are paired field-by-field into an object
// mapper when genMapper is true. generator writes the constructor (only when
// genConstructor is true) and the body of every method in methodSpecs.
func GenerateRepository(persistPackage, domainPackage string, persistStructModel, domainStructModel code.Struct,
	interfaceName string, methodSpecs []spec.MethodSpec, generator RepositoryGenerator, genMapper, genConstructor bool) (string, error) {
	// Build the generator inline and render immediately; nothing else needs it.
	return repositoryGenerator{
		PersistPackageName:    persistPackage,
		DomainPackageName:     domainPackage,
		PersistStructModel:    persistStructModel,
		DomainStructModel:     domainStructModel,
		InterfaceName:         interfaceName,
		MethodSpecs:           methodSpecs,
		GenerateObjectMappter: genMapper,
		GenerateConstructor:   genConstructor,
		Generator:             generator,
	}.Generate()
}

// repositoryGenerator bundles everything needed to render one repository
// implementation file. GenerateRepository is the public way to fill it in.
type repositoryGenerator struct {
	// PersistPackageName is handed to the base template as the generated
	// file's package name; DomainPackageName names the domain model's package.
	PersistPackageName, DomainPackageName string
	// The persist-side (entity) struct and the domain struct whose fields
	// generateObjectMapper pairs up.
	PersistStructModel, DomainStructModel code.Struct
	// Name of the repository interface (not read in this file; presumably
	// consumed by templates or other package code — confirm).
	InterfaceName string
	// One spec per repository method; each is passed to Generator.GenerateMethod.
	MethodSpecs []spec.MethodSpec
	// Feature switches. GenerateObjectMappter (sic; spelling kept to avoid a
	// package-wide rename) enables the entity/domain object mapper, and
	// GenerateConstructor enables Generator.GenerateConstructor.
	GenerateObjectMappter, GenerateConstructor bool
	// Generator renders the constructor and the individual method bodies.
	Generator RepositoryGenerator
}

// RepositoryGenerator supplies the parts of a repository file that
// repositoryGenerator does not render itself: the constructor and the
// implementation of each repository method. Each method writes Go source to w
// and returns any error, which aborts generation of the whole file.
type RepositoryGenerator interface {
	// GenerateConstructor is called once, and only when constructor
	// generation is enabled; it writes the constructor's source to w.
	GenerateConstructor(w io.Writer) error
	// GenerateMethod is called once per method spec, in spec order; it
	// writes the implementation of the method described by method to w.
	GenerateMethod(method spec.MethodSpec, w io.Writer) error
}

// Generate assembles the full repository source file and returns it after
// running it through imports.Process, which formats the code and adds or
// removes import statements as needed.
//
// Sections are written in this order: the base template (rendered with the
// persist package name), the constructor (if GenerateConstructor), the object
// mapper (if GenerateObjectMappter), then one implementation per entry in
// MethodSpecs. The first error from any section is returned unchanged.
//
// If the assembled source cannot be processed (usually because a template or
// the RepositoryGenerator emitted invalid Go), the returned error wraps the
// processor's error and includes the unprocessed source, since the
// processor's line/column positions refer to that text.
func (g repositoryGenerator) Generate() (string, error) {
	buffer := new(bytes.Buffer)
	if err := g.generateBase(buffer); err != nil {
		return "", err
	}

	if g.GenerateConstructor {
		if err := g.Generator.GenerateConstructor(buffer); err != nil {
			return "", err
		}
	}

	if g.GenerateObjectMappter {
		if err := g.generateObjectMapper(buffer); err != nil {
			return "", err
		}
	}

	for _, method := range g.MethodSpecs {
		if err := g.Generator.GenerateMethod(method, buffer); err != nil {
			return "", err
		}
	}

	formattedCode, err := imports.Process("", buffer.Bytes(), nil)
	if err != nil {
		// Library code must not write to stdout; return the raw source inside
		// the error so the caller decides whether and where to show it.
		return "", fmt.Errorf("formatting generated repository code: %w\nunformatted source:\n%s", err, buffer.Bytes())
	}

	return string(formattedCode), nil
}

// generateBase renders baseTemplate into w. The template's only data is the
// persist package name (baseTemplateData.PackageName). Generate writes this
// section first.
//
// w is an io.Writer, matching generateObjectMapper; *bytes.Buffer callers are
// unaffected.
//
// NOTE(review): this file imports html/template, which HTML-escapes
// interpolated values. text/template is the usual choice for emitting Go
// source. A valid Go package name contains no characters html/template
// escapes, so this section is unaffected, but confirm before adding data.
func (g repositoryGenerator) generateBase(w io.Writer) error {
	tmpl, err := template.New("file_base").Parse(baseTemplate)
	if err != nil {
		return fmt.Errorf("parsing base template: %w", err)
	}

	data := baseTemplateData{PackageName: g.PersistPackageName}
	if err := tmpl.Execute(w, data); err != nil {
		return fmt.Errorf("executing base template: %w", err)
	}
	return nil
}

// generateObjectMapper renders the objectMapper template into buffer. Each
// field of the persist (entity) struct is paired with a field of the domain
// struct, visiting persist fields in declaration order:
//
//   - If the field has a "mapper" tag with exactly three values
//     (domainField, domain2EntityFunc, entity2DomainFunc), it is paired with
//     the domain field named domainField, using the given functions.
//     NeedConvert is always true for such pairs.
//   - Otherwise it is paired with the domain field of the same name. The
//     conversion "funcs" are set to the destination type's code:
//     Entity2DomainFunc gets the domain type and Domain2EntityFunc gets the
//     entity type. How they are rendered depends on the objectMapper
//     template. NeedConvert is true only when the two type codes differ.
//
// Persist fields with no matching domain field are omitted from the mapping.
//
// NOTE(review): this file uses html/template, which HTML-escapes interpolated
// values in text context ('<', '&', quotes, '+'). If the template interpolates
// type codes or mapper function names containing these characters (e.g.
// "<-chan T"), the emitted Go would be corrupted. text/template is safer here.
func (g repositoryGenerator) generateObjectMapper(buffer io.Writer) error {
	tmpl, err := template.New("objectMapper").Parse(objectMapper)
	if err != nil {
		return fmt.Errorf("parsing object mapper template: %w", err)
	}

	fieldMap := make([]objectMapperField, 0, len(g.PersistStructModel.Fields))
	for _, entityField := range g.PersistStructModel.Fields {
		// NOTE(review): a mapper tag with any arity other than 3 is silently
		// ignored and the field falls back to name matching — confirm intended.
		mapper, hasMapper := entityField.Tags["mapper"]
		explicit := hasMapper && len(mapper) == 3

		targetName := entityField.Name
		if explicit {
			targetName = mapper[0]
		}

		for _, domainField := range g.DomainStructModel.Fields {
			if domainField.Name != targetName {
				continue
			}
			if explicit {
				fieldMap = append(fieldMap, objectMapperField{
					DomainField:       targetName,
					EntityField:       entityField.Name,
					Entity2DomainFunc: mapper[2],
					Domain2EntityFunc: mapper[1],
					NeedConvert:       true,
				})
			} else {
				domainTypeCode := domainField.Type.Code()
				entityTypeCode := entityField.Type.Code()
				fieldMap = append(fieldMap, objectMapperField{
					DomainField:       domainField.Name,
					EntityField:       entityField.Name,
					Entity2DomainFunc: domainTypeCode,
					Domain2EntityFunc: entityTypeCode,
					NeedConvert:       domainTypeCode != entityTypeCode,
				})
			}
			// Field names are unique within a Go struct: stop at the first match.
			break
		}
	}

	data := objectMapperData{
		EntityName:       g.PersistStructModel.Name,
		DomainObjectCode: g.DomainStructModel.ReferencedType().Code(),
		FieldMap:         fieldMap,
	}
	if err := tmpl.Execute(buffer, data); err != nil {
		return fmt.Errorf("executing object mapper template: %w", err)
	}
	return nil
}
